In [6]:
import torch as T
import numpy as np
from matplotlib import pyplot as pt

import models
import util

device = T.device('cuda:3')
In [3]:
# Load the image plus its high-res (lc) and low-res (nlcd) label maps,
# one-hot encode both, and show all three side by side.

img, lc, nlcd = util.load_data()

# One-hot stacks: channel k is the boolean mask for class k.
lc_oh = np.stack([lc == cls for cls in range(4)])
nlcd_oh = np.stack([nlcd == util.nlcd_cl[cls] for cls in range(22)])

pt.figure(figsize=(12, 4))
for pos, panel in ((131, img[:3].T),
                   (132, util.vis_lc(lc_oh).T),
                   (133, util.vis_nlcd(nlcd_oh).T)):
    pt.subplot(pos)
    pt.imshow(panel)
pt.show()
In [7]:
# Instantiate the epitome model on the selected GPU.

epitome_size = 299
# second argument is 4 — presumably the number of image channels
# (patches below are built as (batch, 4, w, w)); confirm in models.py
ep = models.EpitomeModel(epitome_size, 4).to(device)
In [8]:
# Train the model (best run on GPU); see figure in SI for outputs.
#
# Each iteration samples random square patches from the image, scores them
# under the epitome, and maximizes the masked log likelihood. A usage
# counter masks out epitome positions that have already absorbed posterior
# mass, pushing new data toward fresh positions; once most positions are
# masked the counter resets.

n_batches = 10000
batch_size = 256
show_interval = 100

diversify = False # tiny image, no need to
mask_threshold = 1e-8  # was float('1e-8') — redundant string parse
reset_threshold = 0.95

optimizer = T.optim.Adam(ep.parameters(), lr=0.003)

# per-position usage counter over the epitome grid
counter = T.zeros((ep.size, ep.size), device=device)

for it in range(n_batches):
    w = np.random.randint(10,16)*2+1 # odd number 21 to 31
    
    #construct the batches of random image patches
    batch = np.zeros((batch_size, 4, w, w))
    for b in range(batch_size):
        x = np.random.randint(img.shape[1]-w+1)
        y = np.random.randint(img.shape[2]-w+1)
        batch[b] = img[:,x:x+w,y:y+w]
    x = T.from_numpy(batch).to(device, T.float)
    
    optimizer.zero_grad()
    
    # compute p(x|s)p(s) and smooth (scale by patch area relative to 11x11)
    e = ep(x) / (w/11)**2
    
    # extract worst-modeled quarter of data
    if diversify:
        indices = e.logsumexp((0,2,3)).topk(batch_size//4, largest=False, sorted=False).indices
        e = e[:,indices]
    
    # increment counters and compute mask; detach so this bookkeeping never
    # builds (or pulls on) the autograd graph of e
    with T.no_grad():
        posterior = e.detach().view(-1, ep.size*ep.size).softmax(1).view(-1, ep.size, ep.size).mean(0)
        counter += posterior
    mask = (counter > mask_threshold).float()
    
    # reset counters if threshold reached
    if (mask.mean() > reset_threshold):
        counter[:] = 0.
        mask[:] = 0.
    
    # compute log likelihood of data over unmasked positions (+const);
    # -10000 effectively removes masked positions from the logsumexp
    loss = -T.logsumexp(e - 10000*mask, (0,2,3))
    loss.mean().backward()
    optimizer.step()
    
    # clamp parameters to keep the model well-conditioned
    with T.no_grad(): 
        ep.ivar[:].clamp_(min=1, max=10**2)
        ep.prior[:] -= ep.prior.mean()
        ep.prior[:].clamp_(min=-4., max=4.)

    # show the means periodically
    if it % show_interval == 0:
        pt.imshow(ep.mean.detach().cpu().numpy()[0,:3].T)
        pt.show()
In [10]:
# checkpoint the trained epitome parameters to disk
T.save(ep.state_dict(), 'epitome')
In [14]:
# restore a saved checkpoint; remap tensors that were saved on cuda:2 onto cuda:3
ep.load_state_dict(T.load('epitome', map_location={'cuda:2':'cuda:3'}))
Out[14]:
<All keys matched successfully>
In [15]:
max_patch_size = 31
hw = max_patch_size//2

def label_embed(labels, vis_fn, show=False):
    """Embed label maps into epitome coordinates.

    Samples epitome positions for random image patches and accumulates each
    patch's labels at the sampled positions, yielding a per-class soft label
    map over the (padded) epitome grid.

    labels: one-hot label stack, shape (n_classes, H, W), aligned with img.
    vis_fn: renders an (n_classes, h, w) map to an image for preview.
    show:   if True, plot intermediate maps every 10 batches.

    Returns array of shape
    (n_classes, ep.layers, ep.size+max_patch_size, ep.size+max_patch_size).
    NOTE: reads the module-level `batch_size` set in the training cell.
    """
    n_batches = 51
    n_samples = 64

    ep_map = np.zeros((labels.shape[0], ep.layers, ep.size+max_patch_size, ep.size+max_patch_size)) + 0.00001

    # Fix: the original called T.set_grad_enabled(False), which disabled
    # autograd globally for the rest of the process. The context manager
    # restores the previous grad mode on exit.
    with T.no_grad():
        for it in range(n_batches):

            w = np.random.randint(10,16)*2+1 #size of the patch to compute posterior for
            ew = 11 #size of the center piece of the patch from which to embed labels

            batch = np.zeros((batch_size,4,w,w))
            lc_batch = np.zeros((batch_size,labels.shape[0],w,w))
            for b in range(batch_size):
                x = np.random.randint(img.shape[1]-w+1)
                y = np.random.randint(img.shape[2]-w+1)
                batch[b] = img[:,x:x+w,y:y+w]
                lc_batch[b] = labels[:,x:x+w,y:y+w]

            x = T.from_numpy(batch).to(device, T.float)

            e = ep(x) / (w/11)**2

            temp = max(3-it,1)# take a few samples with higher temperature to fill in the gaps (for pretty picture)
            logits = e.transpose(0,1).reshape(batch_size,-1) / temp
            dist = T.distributions.Categorical(logits=logits.cpu())

            d = (w-ew)//2
            shift = (max_patch_size-ew)//2

            # decode flat sample indices into (layer, x, y) epitome coordinates
            z = dist.sample([n_samples])
            layers = z // (ep.size**2)
            cs = z % (ep.size**2)
            xs, ys = cs//ep.size, cs%ep.size

            # scatter each patch's center labels onto its sampled position
            for s in range(n_samples):
                for j in range(batch_size):
                    layer,x,y = (a[s,j] for a in (layers,xs,ys))
                    ep_map[:,layer,x+shift:x+shift+ew,y+shift:y+shift+ew] += lc_batch[j,:,d:w-d,d:w-d]

            if it%10==0 and show:
                pt.subplot(121)
                pt.imshow(ep.mean.detach().cpu().numpy()[0,:3].T)

                pt.subplot(122)
                pt.imshow(vis_fn(ep_map[:,0,hw:hw+ep.size,hw:hw+ep.size]).T)
                pt.show()

    # wrap around: fold each padded border onto the opposite side
    ep_map[...,:max_patch_size,:] += ep_map[...,-max_patch_size-1:-1,:]
    ep_map[...,-max_patch_size-1:-1,:] = ep_map[...,:max_patch_size,:]
    ep_map[...,:,:max_patch_size] += ep_map[...,:,-max_patch_size-1:-1]
    ep_map[...,:,-max_patch_size-1:-1] = ep_map[...,:,:max_patch_size]

    return ep_map
In [ ]:
def superres(lrmap, max_iter=20):
    """EM super-resolution: estimate high-res class probabilities p(l|s) per
    epitome position from an accumulated low-res (nlcd) label map.

    lrmap:    low-res label counts, shape (n_lr_classes, E, X, Y) — class axis
              first (it is normalized over with .sum(0)).
    max_iter: number of EM iterations.

    Returns p(l|s) as a numpy array of shape (E, X, Y, 4).
    """
    eps = 1e-11  # numerical floor to avoid division by zero
    
    # use the full p(l,c)
    nlcd_mu = util.nlcd_mu
    p_l_c = T.from_numpy(nlcd_mu).float().to(device)

    # init the p(s|c): renormalize priors, then normalize over positions
    p_s_c = (T.from_numpy(lrmap).float() + eps).to(device)
    p_s_c /= (p_s_c.sum(0))
    p_s_c /= p_s_c.sum((1,2,3)).view(-1,1,1,1)
    
    # init the p(l|s): near-uniform random start (unseeded — results vary run to run)
    p_l_s = (T.rand(p_s_c.shape[1:]+(4,))+10).to(device)
    p_l_s /= p_l_s.sum(3).unsqueeze(3)
    # (dead pre-allocation of q removed: it was rebuilt from scratch每 E step)

    for it in range(max_iter):
        # E step: q(s|l,c) proportional to p(l|s) p(s|c), normalized over positions
        q = T.einsum('exyl,cexy->exycl',p_l_s,p_s_c) + eps
        q /= q.sum((0,1,2))

        # M step: re-estimate p(l|s), weighting q by p(l,c)
        p_l_s = T.einsum('exycl,cl->exyl',q,p_l_c)+eps
        p_l_s /= p_l_s.sum(3).unsqueeze(3)

    return p_l_s.cpu().numpy()
In [131]:
# Build the embedding map from the high-resolution labels (lc) and preview
# it next to the epitome means.

ep_map_hr = label_embed(lc_oh, util.vis_lc)

for pos, panel in ((121, ep.mean.detach().cpu().numpy()[0,:3].T),
                   (122, util.vis_lc(ep_map_hr[:,0,hw:hw+ep.size,hw:hw+ep.size]).T)):
    pt.subplot(pos)
    pt.imshow(panel)
pt.show()
In [132]:
# Build the embedding map from the low-resolution labels (nlcd), then
# super-resolve it into a high-res map; preview both against the means.

ep_map_lr = label_embed(nlcd_oh, util.vis_nlcd)

# axes rearranged so the result is indexed like ep_map_hr: (class, layer, x, y)
ep_map_sr = superres(ep_map_lr).T.swapaxes(1,3)

for vis, emap in ((util.vis_nlcd, ep_map_lr), (util.vis_lc, ep_map_sr)):
    pt.subplot(121)
    pt.imshow(ep.mean.detach().cpu().numpy()[0,:3].T)
    pt.subplot(122)
    pt.imshow(vis(emap[:,0,hw:hw+ep.size,hw:hw+ep.size]).T)
    pt.show()
In [51]:
def segment(ep_map, vis_fn=util.vis_lc):
    """Segment the image: sample epitome positions for random image patches
    and copy the matching labels from `ep_map` back into image space.

    ep_map: per-class map over the padded epitome grid, as produced by
            label_embed — shape (n_classes, ep.layers, padded, padded).
    vis_fn: visualization used for the intermediate previews.

    Returns the count-normalized reconstruction, shape (4,) + img.shape[1:].
    NOTE: reads the module-level `batch_size` set in the training cell.
    """
    n_batches = 51
    n_samples = 64

    reconstruction = np.zeros((4,) + img.shape[1:])
    counts = np.zeros(img.shape[1:])+0.000001  # floor avoids /0 on unvisited pixels

    for it in range(n_batches):

        #w = np.random.randint(10,16)*2+1 #size of the patch to compute posterior for
        w = 11
        # making w smaller will make small features more likely to appear in reconstruction
        ew = 11 #size of the center piece of the patch from which to copy labels from ep_map

        batch = np.zeros((batch_size,4,w,w))
        coords = []
        for b in range(batch_size):
            x = np.random.randint(img.shape[1]-w+1)
            y = np.random.randint(img.shape[2]-w+1)
            coords.append((x,y))
            batch[b] = img[:,x:x+w,y:y+w]

        x = T.from_numpy(batch).to(device, T.float)

        # Fix: inference only, so disable grad explicitly — the original
        # silently depended on a previous cell having turned autograd off
        # globally.
        with T.no_grad():
            e = ep(x) / (w/11)**2

        logits = e.transpose(0,1).reshape(batch_size,-1)
        dist = T.distributions.Categorical(logits=logits.cpu())  

        d = (w-ew)//2
        shift = (max_patch_size-ew)//2

        # decode flat sample indices into (layer, x, y) epitome coordinates
        z = dist.sample([n_samples])
        layers = z // (ep.size**2)
        cs = z % (ep.size**2)
        xs, ys = cs//ep.size, cs%ep.size

        # copy labels from the sampled epitome windows back to the patches' locations
        for s in range(n_samples):
            for j in range(batch_size):
                layer,x,y = (a[s,j] for a in (layers,xs,ys))
                cx,cy = coords[j]
                reconstruction[:,cx+d:cx+d+ew,cy+d:cy+d+ew] += ep_map[:,layer,x+shift:x+shift+ew,y+shift:y+shift+ew] 
                counts[cx+d:cx+d+ew,cy+d:cy+d+ew] += 1

        if it%10==0:# and show:
            pt.figure(figsize=(12,4))
            pt.subplot(131);pt.title('image')
            pt.imshow(img[:3].T)
            pt.subplot(132);pt.title('prediction')
            pt.imshow(vis_fn(reconstruction/counts).T)
            pt.subplot(133);pt.title('gt')
            pt.imshow(util.vis_lc(lc_oh).T)
            pt.show()

    return reconstruction/counts
In [150]:
# segment the image using the embedding built from the high-res (lc) labels

rec_hr = segment(ep_map_hr)
In [151]:
# segment the image using the embedding built from the low-res nlcd labels
# followed by EM super-resolution

rec_sr = segment(ep_map_sr)
In [52]:
# Treat the epitome means themselves as the "labels" so segment() rebuilds
# the image; pad both spatial axes with wrap-around (same as the manual
# head/tail concatenation) to match ep_map's padded layout.

s = max_patch_size // 2
ep_map_rgbi = np.pad(ep.mean[0].detach().cpu().numpy(),
                     ((0, 0), (s, s), (s, s)), mode='wrap')

rec_img = segment(ep_map_rgbi[:,None], vis_fn=lambda x:x[:3])
In [67]:
# Compare the original image with its epitome reconstruction, plus a
# per-pixel squared-error map restricted to covered pixels.

pt.figure(figsize=(15,5))
pt.subplot(131)
pt.imshow(img[:3].T)
pt.subplot(132)
pt.imshow(rec_img[:3].T)
pt.subplot(133)
covered = rec_img.sum(0) > 0
sq_err = ((img - rec_img)**2).sum(0) * covered
pt.imshow(sq_err.T, cmap='Reds', interpolation='none')
Out[67]:
<matplotlib.image.AxesImage at 0x7fa2a404f580>